{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QWwStSaZAK70"
   },
   "outputs": [],
   "source": [
    "pip install pytorch_lightning --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lKHqRaW6_ree"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import nltk\n",
    "from PIL import Image\n",
    "\n",
    "import torch\n",
    "import torch.utils.data as data\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from model import HybridModel\n",
    "from vocabulary import Vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jXU6Ulg9BoNA"
   },
   "outputs": [],
   "source": [
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3T_HKi7O_n4A"
   },
   "outputs": [],
   "source": [
    "class CocoDataset(data.Dataset):\n",
    "    def __init__(self, data_path, json_path, vocabulary, transform=None):\n",
    "        self.image_dir = data_path\n",
    "        self.vocabulary = vocabulary\n",
    "        self.transform = transform\n",
    "        with open(json_path) as json_file:\n",
    "            self.coco = json.load(json_file)\n",
    "        self.image_id_file_name = dict()\n",
    "        for image in self.coco['images']:\n",
    "            self.image_id_file_name[image['id']] = image['file_name']\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        annotation = self.coco['annotations'][idx]\n",
    "        caption = annotation['caption']\n",
    "        tkns = nltk.tokenize.word_tokenize(str(caption).lower())\n",
    "        caption = []\n",
    "        caption.append(self.vocabulary('<start>'))\n",
    "        caption.extend([self.vocabulary(tkn) for tkn in tkns])\n",
    "        caption.append(self.vocabulary('<end>'))\n",
    "\n",
    "        image_id = annotation['image_id']\n",
    "        image_file = self.image_id_file_name[image_id]\n",
    "        image_path = os.path.join(self.image_dir, image_file)\n",
    "        image = Image.open(image_path).convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "\n",
    "        return image, torch.Tensor(caption)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.coco['annotations'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KyhtWATq_4DG"
   },
   "outputs": [],
   "source": [
    "def coco_collate_fn(data_batch):\n",
    "    data_batch.sort(key=lambda d: len(d[1]), reverse=True)\n",
    "    imgs, caps = zip(*data_batch)\n",
    "\n",
    "    imgs = torch.stack(imgs, 0)\n",
    "\n",
    "    cap_lens = [len(cap) for cap in caps]\n",
    "    padded_caps = torch.zeros(len(caps), max(cap_lens)).long()\n",
    "    for i, cap in enumerate(caps):\n",
    "        end = cap_lens[i]\n",
    "        padded_caps[i, :end] = cap[:end]\n",
    "    return imgs, padded_caps, cap_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EBarWLoi_9pR"
   },
   "outputs": [],
   "source": [
    "def get_loader(data_path, json_path, vocabulary, transform, batch_size, shuffle, num_workers=0):\n",
    "    coco_ds = CocoDataset(data_path=data_path,\n",
    "                          json_path=json_path,\n",
    "                          vocabulary=vocabulary,\n",
    "                          transform=transform)\n",
    "    coco_dl = data.DataLoader(dataset=coco_ds,\n",
    "                            batch_size=batch_size,\n",
    "                            shuffle=shuffle,\n",
    "                            num_workers=num_workers,\n",
    "                            collate_fn=coco_collate_fn)\n",
    "    return coco_dl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nb6IPGlgAbcr"
   },
   "outputs": [],
   "source": [
    "transform = transforms.Compose([\n",
    "    transforms.RandomCrop(224),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.485, 0.456, 0.406),\n",
    "                          (0.229, 0.224, 0.225))])\n",
    "\n",
    "with open('coco_data/vocabulary.pkl', 'rb') as f:\n",
    "    vocabulary = pickle.load(f)\n",
    "\n",
    "coco_data_loader = get_loader('coco_data/images',\n",
    "                              'coco_data/captions.json',\n",
    "                              vocabulary,\n",
    "                              transform,\n",
    "                              128,\n",
    "                              shuffle=True,\n",
    "                              num_workers=4)\n",
    "\n",
    "hybrid_model = HybridModel(256, 256, 512,\n",
    "                           len(vocabulary), 1)\n",
    "trainer = pl.Trainer(max_epochs=5)\n",
    "trainer.fit(hybrid_model, coco_data_loader)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "train_model.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
